import torch
import random
import numpy as np
import os
import os
import os.path

from torchvision import datasets, transforms
from util.dataset_class.Cub2011 import Cub2011
from util.dataset_class.Cifar100SubClass import Cifar100SubClass
from util.dataset_class.ImageNetSubClass import ImageNetSubClass


def make_state_id(args):
    """Build an underscore-joined run identifier (used as a model filename).

    The identifier is composed of, in order:
      * ``args.additional_info`` (only when it is not None),
      * architecture, activation type, operation order, ``seed<N>`` and dataset,
      * a weight-decay tag:
          - ``no<Weight><Gamma><Beta>Decay`` if any ``no_*_decay`` flag is set,
          - the single decay value if weight/gamma/beta decays are all equal,
          - ``weight<w>gamma<g>beta<b>`` otherwise,
      * the learning rate.
    """
    parts = [] if args.additional_info is None else [args.additional_info]
    parts.extend([
        args.arch,
        args.activation_type,
        args.operation_order,
        'seed' + str(args.seed),
        args.dataset,
    ])

    # (flag, label) pairs in the order the labels appear in the tag.
    disabled = [
        (args.no_weight_decay, 'Weight'),
        (args.no_gamma_decay, 'Gamma'),
        (args.no_beta_decay, 'Beta'),
    ]
    if any(flag for flag, _ in disabled):
        labels = ''.join(label for flag, label in disabled if flag)
        parts.append('no' + labels + 'Decay')
    elif args.weight_decay == args.gamma_decay == args.beta_decay:
        parts.append(str(args.weight_decay))
    else:
        parts.append(
            'weight' + str(args.weight_decay)
            + 'gamma' + str(args.gamma_decay)
            + 'beta' + str(args.beta_decay)
        )

    parts.append(str(args.lr))
    return '_'.join(parts)

def fix_randomness(seed, cuda_available):
    """Seed Python, NumPy and PyTorch RNGs with ``seed`` for reproducible runs.

    When ``cuda_available`` is truthy, also seeds the CUDA RNG and switches
    cuDNN to deterministic, non-benchmarking mode.
    """
    for seed_fn in (torch.random.manual_seed, torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if not cuda_available:
        return

    torch.cuda.manual_seed(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_data_loader(dataset, batch_size, test_batch_size):
    if os.path.isdir('/data'):
        data_dir = '/data'
    else:
        data_dir = './data'

    if dataset == 'SVHN':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        data_path = os.path.join(data_dir, 'svhn')

        train_data = datasets.SVHN(root=data_path, split='train', download=True, transform=transform)
        test_data = datasets.SVHN(root=data_path, split='test', download=True, transform=transform)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False)

        num_classes = 10
    elif dataset == 'shapeset':
        transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
        ])
        #transforms.Lambda(lambda x: (x*2/255)-1)
        #])
        data_path = os.path.join(data_dir, 'shapeset')

        train_data = datasets.ImageFolder(os.path.join(data_path, 'train'), transform)
        test_data = datasets.ImageFolder(os.path.join(data_path, 'valid'), transform)

        train_loader=torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=32, pin_memory=True)
        test_loader=torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, num_workers=32, pin_memory=True)

        num_classes = 9
    elif dataset == 'MNIST':
        mnist_transform = transforms.Compose([
            transforms.ToTensor(), 
            transforms.Normalize((0.5,), (1.0,))
        ])

        data_path = os.path.join(data_dir, 'mnist')
        os.makedirs(data_path, exist_ok=True)

        train_data = datasets.MNIST(data_path, transform=mnist_transform, train=True, download=True)
        test_data = datasets.MNIST(data_path, transform=mnist_transform, train=False, download=True)

        train_loader = torch.utils.data.DataLoader(dataset=train_data, 
                         batch_size=batch_size,
                         shuffle=True)

        test_loader = torch.utils.data.DataLoader(dataset=test_data, 
                         batch_size=batch_size,
                         shuffle=False)

        num_classes=10
        
    elif dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        data_path = os.path.join(data_dir, 'cifar10')


        train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train)
        test_data = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                                   num_workers=2)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False,
                                                  num_workers=2)

        num_classes = 10
    elif dataset == 'cifar100':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        data_path = os.path.join(data_dir, 'cifar100')

        train_data = datasets.CIFAR100(root=data_path, train=True, download=True, transform=transform_train)
        test_data = datasets.CIFAR100(root=data_path, train=False, download=True, transform=transform_test)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                                   num_workers=2)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False,
                                                  num_workers=2)

        num_classes = 100

    elif 'cifar100_1' in dataset:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        data_path = os.path.join(data_dir, 'cifar100')

        set_default_class = len(dataset.split('_')) == 2

        if set_default_class:
            target_class = 11
        else:
            target_class = int(dataset.split('_')[-1])

        exclude_list = [c for c in range(100) if c != target_class]

        train_data = Cifar100SubClass(root=data_path, train=True, download=True, transform=transform_train, exclude_list=exclude_list)
        test_data = Cifar100SubClass(root=data_path, train=False, download=True, transform=transform_test, exclude_list=exclude_list)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, num_workers=2)

        num_classes = 100

    elif dataset == 'FashionMNIST':
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,))])
        data_path = os.path.join(data_dir, 'fashionmnist')

        train_data = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)
        test_data = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, num_workers=2)

        num_classes = 10
    elif dataset == 'tinyImageNet':
        transform_train = transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262))
        ])
        transform_test = transforms.Compose([
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262))
        ])
        data_path = os.path.join(data_dir, 'tiny-imagenet')

        train_data = datasets.ImageFolder(os.path.join(data_path, 'train'), transform_train)
        test_data = datasets.ImageFolder(os.path.join(data_path, 'val'), transform_test)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, pin_memory=True)

        num_classes = 200
    elif 'ImageNet' in dataset :
        # transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]), ])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]), ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]), ])

        data_path = os.path.join(data_dir, 'imagenet')

        if dataset != 'ImageNet':
            import re
            classes = re.sub(r'[^0-9]', '', dataset)
            num_classes = int(classes)

            train_data = ImageNetSubClass(root=os.path.join(data_path, 'train'), transform=transform_train, class_num=num_classes)
            test_data = ImageNetSubClass(root=os.path.join(data_path, 'val3'), transform=transform_test, class_num=num_classes)

        else:
            num_classes = 1000

            train_data = datasets.ImageFolder(root=os.path.join(data_path, 'train'), transform=transform_train)
            test_data = datasets.ImageFolder(root=os.path.join(data_path, 'val3'), transform=transform_test)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, num_workers=8, pin_memory=True)
            
    elif dataset == 'cub200':
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_train = \
            transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        data_path = os.path.join(data_dir, 'cub200_2011')

        train_data = Cub2011(os.path.join(data_path), train=True, transform=transform_train)
        test_data = Cub2011(os.path.join(data_path), train=False, transform=transform_test)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                                   num_workers=4, pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False,
                                                  num_workers=4, pin_memory=True)

        num_classes = 200
    elif dataset == 'flower102':
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        transform_train = \
            transforms.Compose([
                transforms.Resize(256),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        data_path = os.path.join(data_dir, 'flower102')

        train_data = datasets.ImageFolder(root=os.path.join(data_path, 'flower_data', 'train'), transform=transform_train)
        test_data = datasets.ImageFolder(root=os.path.join(data_path, 'flower_data', 'valid'), transform=transform_test)

        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                                   num_workers=4, pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False,
                                                  num_workers=4, pin_memory=True)
        num_classes = 102
    else:
        print('Invalid dataset!')
        exit(0)

    return train_loader, test_loader, num_classes

def _remap_state_dict_by_position(model_weight, pretrained_weight):
    """Copy checkpoint tensors into ``model_weight``, pairing unmatched names by position.

    A checkpoint key that also exists in the model is copied under its own
    name. Any other key is copied onto the model key at the same index in
    iteration order. ``model_weight`` is updated in place and returned.

    Raises:
        RuntimeError: if an unmatched checkpoint key has no model key at its index.
    """
    model_keys = list(model_weight.keys())
    model_key_set = set(model_keys)
    for ind, (key, weight) in enumerate(pretrained_weight.items()):
        if key in model_key_set:
            model_weight[key] = weight
        elif ind < len(model_keys):
            model_weight[model_keys[ind]] = weight
        else:
            raise RuntimeError(
                f"cannot remap checkpoint key {key!r} at position {ind}: "
                f"model has only {len(model_keys)} state entries")
    return model_weight


def get_pretrained_weight(model, args):
    """Load the checkpoint at ``args.pretrained`` into ``model`` and return it.

    The checkpoint is expected to be a dict with ``'acc'`` and
    ``'state_dict'`` entries, as written by ``save_state``. It is mapped to
    CUDA when ``args.cuda`` is truthy, otherwise to CPU. A strict
    ``load_state_dict`` is tried first. If it fails with ``RuntimeError``
    (e.g. renamed or shifted layer names), unmatched checkpoint entries are
    paired with model entries by position and loaded instead.

    Returns:
        The full loaded checkpoint dict.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    pretrained_model = torch.load(args.pretrained, map_location="cuda" if args.cuda else "cpu")

    best_acc = pretrained_model['acc']
    pretrained_weight = pretrained_model['state_dict']

    try:
        model.load_state_dict(pretrained_weight)
    except RuntimeError:
        # Use when layers were pushed back: names no longer line up, so fall
        # back to pairing parameters by their position in the state dict.
        print("layer adjusting..")
        model.load_state_dict(_remap_state_dict_by_position(model.state_dict(), pretrained_weight))

    print("model accuracy: ", best_acc)

    return pretrained_model

def save_state(model, acc, model_filename, save_dir='saved_models/'):
    """Save ``model``'s weights and ``acc`` to ``save_dir/model_filename``.

    The checkpoint is a dict ``{'acc': acc, 'state_dict': ...}``. A leading
    ``'module.'`` prefix on state-dict keys (as added by a ``DataParallel``
    wrapper) is stripped, so the file loads into an unwrapped model.

    Args:
        model: module whose ``state_dict()`` is saved.
        acc: accuracy value stored alongside the weights.
        model_filename: file name inside ``save_dir``.
        save_dir: output directory, created if missing (default ``'saved_models/'``).
    """
    print('==> Saving model ...')
    state = {
        'acc': acc,
        'state_dict': model.state_dict(),
    }
    state_dict = state['state_dict']
    prefix = 'module.'
    # Strip only a *leading* prefix: a substring replace would also mangle
    # inner names such as 'layer.submodule.weight'.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, model_filename))

def gini_coefficient(x):
    """Compute Gini coefficient of array of values"""
    diffsum = 0
    x = np.sort(x)

    for i, xi in enumerate(x[:-1], 1):
        diffsum += np.sum(np.abs(xi - x[i:]))
    return diffsum / (len(x)**2 * np.mean(x))

def l1_reg(model, l1_reg_target):
    """Return an L1 penalty over the parameters this heuristic treats as gamma/beta.

    Parameters are walked in ``model.named_parameters()`` order and grouped
    purely by position:
      * a 4-D parameter (conv weight) sets a 3-parameter group
        ``(weight, gamma, beta)``;
      * a 2-D parameter (linear weight) sets a 4-parameter group
        ``(weight, bias, gamma, beta)``.
    The group position is tracked with a running counter modulo the group
    size. It is not reset when a new weight is seen. The layout presumably
    matches this project's models (bias-free conv and biased fc, each
    followed by a gamma/beta layer) — TODO confirm. Any other ordering
    penalizes the wrong tensors.

    Args:
        model: module whose parameters are scanned.
        l1_reg_target: mapping with truthy/falsy ``'gamma'`` and ``'beta'``
            entries selecting which group members contribute.

    Returns:
        Scalar tensor: sum of ``torch.norm(param, 1)`` over the selected
        parameters, starting from ``torch.tensor(0., requires_grad=True)``.

    Note:
        If the first parameter is neither 2-D nor 4-D, ``block_param_num`` is
        still 0 and ``param_index %= block_param_num`` raises ZeroDivisionError.
    """
    reg_loss = torch.tensor(0., requires_grad=True)
    param_index = 0  # position of the current parameter within its group
    block_param_num = 0  # size of the current group (3 for conv, 4 for fc)

    for name, param in model.named_parameters():
        if len(param.shape) == 4:  # conv: no bias, so conv weight, gamma, beta form one group
            block_param_num = 3
        elif len(param.shape) == 2:  # fc: has a bias, so linear weight, bias, gamma, beta form one group
            block_param_num = 4

        # gamma sits at index 1 of a conv group and index 2 of an fc group.
        if ((block_param_num == 3 and param_index == 1) or
                (block_param_num == 4 and param_index == 2)):
            if l1_reg_target['gamma']:
                reg_loss = reg_loss + torch.norm(param, 1)

        # beta sits at index 2 of a conv group and index 3 of an fc group.
        elif ((block_param_num == 3 and param_index == 2) or
              (block_param_num == 4 and param_index == 3)):
            if l1_reg_target['beta']:
                reg_loss = reg_loss + torch.norm(param, 1)

        param_index += 1
        param_index %= block_param_num

    return reg_loss

def print_model_parameters(model):
    """Print the total element count of ``model``'s trainable (``requires_grad``) parameters."""
    trainable_count = sum(
        np.prod(param.size()) for param in model.parameters() if param.requires_grad
    )
    print('Total parameter number:', trainable_count, '\n')

